import os

from PIL import Image


def get_file_paths(folder):
    """Return absolute paths of the .png/.jpg files directly inside ``folder``.

    Only the top level of ``folder`` is inspected; subdirectories are never
    searched.  Results are ordered by file name (plain string order) and the
    extension match is case-sensitive.  A folder that cannot be listed
    (e.g. it does not exist) yields an empty list.
    """
    walker = os.walk(folder)
    try:
        # The first item os.walk produces describes ``folder`` itself; we never
        # ask for more, so nested directories are skipped.
        top_root, _subdirs, top_files = next(walker)
    except StopIteration:
        return []

    base_dir = os.path.abspath(top_root)
    return [
        os.path.join(base_dir, name)
        for name in sorted(top_files)
        if name.endswith(('.png', '.jpg'))
    ]


def _save_resized(src_path, dst_path, size):
    """Load ``src_path`` as RGB, resize it to ``size`` and save it as a JPEG."""
    # The context manager releases the source file handle deterministically;
    # convert() loads the pixel data, so the result outlives the handle.
    with Image.open(src_path) as src:
        img = src.convert('RGB').resize(size)
    img.save(dst_path, format='JPEG', subsampling=0, quality=100)


def make_images(file_paths, label_paths, target_path, max_num=1000,
                num_test=500, size=(256, 256)):
    """Write paired image/label files into a train/test A/B folder layout.

    The image at ``file_paths[i]`` and the label at ``label_paths[i]`` are
    each converted to RGB, resized to ``size`` and saved as maximum-quality
    JPEGs.  Both are saved as ``<stem>.jpg``, where ``<stem>`` is the image's
    file name without its extension.  Images go to the ``*A`` folders and
    labels to the ``*B`` folders.  The last ``num_test`` pairs are written to
    ``testA``/``testB`` and all earlier pairs to ``trainA``/``trainB``.  If
    there are fewer than ``num_test`` pairs, every pair goes to the test
    split.

    Args:
        file_paths: Paths of the input images (domain A).
        label_paths: Paths of the label images (domain B), paired with
            ``file_paths`` by position.
        target_path: Output directory. The ``trainA``, ``trainB``, ``testA``
            and ``testB`` subfolders are created inside it if missing.
        max_num: Currently unused; kept so existing callers keep working.
        num_test: Number of trailing pairs written to the test split.
        size: ``(width, height)`` that every output image is resized to.

    Raises:
        ValueError: If ``file_paths`` and ``label_paths`` differ in length.
    """
    # Validate before touching the filesystem: images and labels are paired
    # purely by index, so a length mismatch means the pairing is wrong.
    if len(file_paths) != len(label_paths):
        raise ValueError(
            'got %d images but %d labels; they must pair up one-to-one'
            % (len(file_paths), len(label_paths)))

    trainA_path = os.path.join(target_path, 'trainA')
    trainB_path = os.path.join(target_path, 'trainB')
    testA_path = os.path.join(target_path, 'testA')
    testB_path = os.path.join(target_path, 'testB')
    for out_dir in (trainA_path, trainB_path, testA_path, testB_path):
        os.makedirs(out_dir, exist_ok=True)

    first_test_index = len(file_paths) - num_test
    pairs = zip(file_paths, label_paths)
    for i, (img_path, label_path) in enumerate(pairs):
        # Strip only the final extension of the file name itself. Splitting
        # the whole path on '.' broke on dotted directories (e.g.
        # /home/john.doe/...) and dotted names, causing overwrites.
        stem = os.path.splitext(os.path.basename(img_path))[0]
        out_name = '%s.jpg' % stem
        if i >= first_test_index:
            dir_a, dir_b = testA_path, testB_path
        else:
            dir_a, dir_b = trainA_path, trainB_path
        _save_resized(img_path, os.path.join(dir_a, out_name), size)
        _save_resized(label_path, os.path.join(dir_b, out_name), size)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(
        description='Build a paired image-to-image dataset (trainA/trainB/'
                    'testA/testB) from an image folder and a label folder.')
    # Both folders are read non-recursively by get_file_paths, so the files
    # must sit directly inside them.
    parser.add_argument(
        '--img_path',
        required=True,
        help='Folder containing the input images (.png/.jpg) directly; '
             'subfolders are ignored')
    parser.add_argument(
        '--label_path',
        required=True,
        help='Folder containing the label images (.png/.jpg) directly, '
             'paired with the input images by position in sorted file-name '
             'order; subfolders are ignored')
    parser.add_argument(
        '--target_path',
        default='GTA5_I2I',
        help='Output folder for the generated dataset '
             '(default: %(default)s)')
    args = parser.parse_args()

    file_paths = get_file_paths(args.img_path)
    label_paths = get_file_paths(args.label_path)
    make_images(file_paths, label_paths, args.target_path)

